############################################
# Experiment
############################################

import os
import sys
import pandas as pd
import numpy as np

import platform
# !python -V

sys.path.append('[PROJECT_PATH]/Code')
os.chdir('[PROJECT_PATH]/Code')

from utils import classify_tweets

# # Check CUDA
# import torch
# print(torch.cuda.is_available())

#----------------------
# Set Hyperparameters
#----------------------

# Model family name; used in the output directory and the model name.
MODEL_NAME = "Full"
# Codebook IDs (as strings) used as multi-label targets. Each must match a
# `codebook_id` in the codebook CSV and a column in the coded-data CSV.
LABELS = [
    "1", "2", "6", "7",
    "9", "10", "11", "12", "21", "22",
    "29", "30", "32", "33", "35", "36", "38", "39",
    ]

MODEL_TYPE = 'Multi-label'

# Command-line arguments: <subset> <random_seed>
#   subset       'txt' trains on untreated rows (treatment == 0); any other
#                value trains on treated rows (treatment == 1), see prepare_data.
#   random_seed  integer seed for the train/validation split and for training.
if len(sys.argv) < 3:
    sys.exit("Usage: python <script> <subset> <random_seed>")

SUBSET = sys.argv[1]
print(f"Subset: {SUBSET}")

RANDOM_SEED = sys.argv[2]
print(f"Random seed: {RANDOM_SEED}")
try:
    RANDOM_SEED = int(RANDOM_SEED)
except ValueError:
    sys.exit(f"random_seed must be an integer, got {RANDOM_SEED!r}")

# One output directory per (model, subset, seed) combination.
model_dir = f"[CLASSIFIERS_PATH]/{MODEL_NAME}/{MODEL_TYPE} {MODEL_NAME} {SUBSET} {RANDOM_SEED}"

os.makedirs(model_dir, exist_ok=True)


#----------------------
# Helper for data prep
#----------------------

def prepare_data(label_select, subset, random_seed, *, train_frac=0.8,
                 path_codebook=None, path_data=None):
    """Load the coded tweets and split them into training and validation sets.

    Rows are shuffled with ``random_seed`` and the split is made on unique
    ``tweet_id_str`` values, so all rows of one tweet end up on the same side.

    Parameters
    ----------
    label_select : list of str
        Codebook IDs to use as targets. Each must match a ``codebook_id`` in
        the codebook CSV and be a column of the coded-data CSV.
    subset : str
        ``'txt'`` keeps only untreated training rows (``treatment == 0``);
        any other value keeps only treated training rows (``treatment == 1``).
    random_seed : int
        Seed for the shuffle that determines the split.
    train_frac : float, optional
        Fraction of unique tweets assigned to training (default 0.8).
    path_codebook : str, optional
        Path of the codebook CSV; defaults to the module-level ``PATH_CODEBOOK``.
    path_data : str, optional
        Directory holding ``coded_label_wide_experiment.csv``; defaults to the
        module-level ``PATH_DATA``.

    Returns
    -------
    d_tr, d_va : pandas.DataFrame
        Columns ``tweet_id_str``, ``text`` and ``labels`` (one array of the
        ``label_select`` column values per row). The validation set always
        contains only treated rows (``treatment == 1``).
    codebook : list of str
        ``"<category> - <label>"`` for each entry of ``label_select``, in order.

    Raises
    ------
    ValueError
        If an entry of ``label_select`` has no matching ``codebook_id``.
    """
    # Resolve at call time: the module-level paths are defined after this function.
    if path_codebook is None:
        path_codebook = PATH_CODEBOOK
    if path_data is None:
        path_data = PATH_DATA

    # Load codebook and build the readable name of every selected label
    d_codebook = pd.read_csv(path_codebook)
    d_codebook['cat_label'] = d_codebook['category'] + ' - ' + d_codebook['label']
    codebook = []
    for lab in label_select:
        # The label is interpolated as a literal, so "1" matches a numeric codebook_id of 1
        matches = d_codebook.query(f"codebook_id == {lab}")['cat_label']
        if matches.empty:
            raise ValueError(f"Label {lab!r} not found in codebook {path_codebook!r}")
        codebook.append(matches.iloc[0])

    # Load coded tweets and shuffle them reproducibly
    d = pd.read_csv(f'{path_data}/coded_label_wide_experiment.csv')
    d = d.sample(frac = 1, random_state = random_seed)

    # Split on unique tweet IDs into training and validation sets (no tweet in both)
    tweet_ids = d.tweet_id_str.unique()
    train_id, val_id = np.split(tweet_ids, [int(train_frac * len(tweet_ids))])
    d_tr = d[d.tweet_id_str.isin(train_id)]
    d_va = d[d.tweet_id_str.isin(val_id)]

    # Training rows depend on the requested subset; validation always uses the
    # treated group. .copy() so the label column below is set on an independent frame.
    train_treatment = 0 if subset == 'txt' else 1
    d_tr = d_tr[d_tr['treatment'] == train_treatment].copy()
    d_va = d_va[d_va['treatment'] == 1].copy()

    # Collapse the selected label columns into one array per row
    d_tr['labels'] = list(d_tr[label_select].values)
    d_va['labels'] = list(d_va[label_select].values)
    keep = ['tweet_id_str', 'text', 'labels']
    return d_tr[keep].copy(), d_va[keep].copy(), codebook



#----------------------
# Load input data (simple)
#----------------------

# Relative paths: resolved against the working directory set by os.chdir() at the top.
PATH_DATA = '../Data'
PATH_CODEBOOK = '../Data/coded_codebook.csv'

# Get the dataset of interest
d_tr, d_va, codebook = prepare_data(LABELS, subset = SUBSET, random_seed = RANDOM_SEED)

train_ids = d_tr.tweet_id_str.unique()
val_ids = d_va.tweet_id_str.unique()
print("ID training set: ", train_ids, len(train_ids))
print("ID validation set: ", val_ids, len(val_ids))

# prepare_data splits by unique tweet ID, so no validation tweet should appear
# in the training set; stop before training rather than just logging it.
n_leaked = d_va.tweet_id_str.isin(d_tr.tweet_id_str).sum()
print("Check for data leakage: ", n_leaked)
if n_leaked:
    raise RuntimeError(
        f"{n_leaked} validation rows share a tweet_id_str with the training set"
    )

d = (d_tr, d_va)

#----------------------------
# Train classifier (simple)
#----------------------------

# Keyword configuration for the project's classify_tweets wrapper (from utils;
# its internals are not visible in this script).
classifier_kwargs = dict(
    transformer_spec=['roberta', 'roberta-large'],
    model_name=f"{MODEL_TYPE} {MODEL_NAME} {SUBSET}",
    model_directory=model_dir,
    model_type=MODEL_TYPE,
    n_epochs=20,
    data=d,
    codebook=codebook,
    random_seed=RANDOM_SEED,
    weighted=True,
)
worker = classify_tweets(**classifier_kwargs)

# Pipeline steps, presumably order-dependent: data setup -> model setup -> training.
worker.data_setup()
worker.setup_model(save_model_every_epoch=False, save_no_model=True)
worker.train_model()

